package com.sixbro.data.elasticsearch.service.base;

import com.sixbro.data.elasticsearch.common.PageQuery;
import com.sixbro.data.elasticsearch.common.SortParam;
import com.sixbro.data.elasticsearch.util.CollectionUtils;
import lombok.extern.slf4j.Slf4j;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.sort.SortBuilders;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.annotation.Id;
import org.springframework.data.domain.*;
import org.springframework.data.elasticsearch.core.ElasticsearchRestTemplate;
import org.springframework.data.elasticsearch.core.RefreshPolicy;
import org.springframework.data.elasticsearch.core.SearchHit;
import org.springframework.data.elasticsearch.core.SearchHits;
import org.springframework.data.elasticsearch.core.document.Document;
import org.springframework.data.elasticsearch.core.query.BulkOptions;
import org.springframework.data.elasticsearch.core.query.NativeSearchQuery;
import org.springframework.data.elasticsearch.core.query.NativeSearchQueryBuilder;
import org.springframework.data.elasticsearch.core.query.UpdateQuery;
import org.springframework.data.elasticsearch.repository.ElasticsearchRepository;

import java.lang.reflect.Field;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;


/**
 * <p>
 * Base implementation of {@link IESBaseService} built on a Spring Data {@link ElasticsearchRepository}
 * and an {@link ElasticsearchRestTemplate}.
 * </p>
 * <p>
 * Plain CRUD calls are delegated to {@link #repository}. Query-by-example deletes, partial (non-null field)
 * updates and {@link PageQuery}-driven searches are built by hand with the template. Subclasses add their
 * search conditions by overriding {@link #buildFilterCondition(BoolQueryBuilder, PageQuery)}.
 * </p>
 * <p>
 * NOTE(review): {@code getEntityClass()}, {@code getEntityAllField()}, {@code sortFields(query)} and
 * {@code returnFields()} are not defined in this class; presumably they are declared on
 * {@code IESBaseService} or implemented by subclasses — confirm.
 * </p>
 *
 * @param <T>  entity (document) type
 * @param <ID> type of the entity field annotated with {@link Id}
 * @param <M>  repository type for {@code T}
 * @author: Mr.Lu
 * @since: 2021-12-14 17:24
 */
@Slf4j
public abstract class ESBaseServiceImpl<T, ID, M extends ElasticsearchRepository<T, ID>> implements IESBaseService<T, ID> {

    /** Repository for {@code T}; optional ({@code required = false}), so repository-backed methods fail if it is absent. */
    @Autowired(required = false)
    public M repository;

    /** Template used for the hand-built queries, deletes and bulk updates. */
    @Autowired
    public ElasticsearchRestTemplate elasticsearchRestTemplate;

    @Override
    public <S extends T> S save(S entity) {
        return repository.save(entity);
    }

    @Override
    public <S extends T> Iterable<S> saveAll(Iterable<S> entities) {
        return repository.saveAll(entities);
    }

    @Override
    public Optional<T> findById(ID id) {
        return repository.findById(id);
    }

    @Override
    public boolean existsById(ID id) {
        return repository.existsById(id);
    }

    /**
     * Returns the documents matched by an empty {@link PageQuery}, or {@code null} when nothing matches.
     * <p>
     * NOTE(review): this delegates to an unpaged {@link #list(PageQuery)} search, so Spring Data
     * Elasticsearch's default page size may cap the number of returned documents — confirm.
     * </p>
     */
    @Override
    public Collection<T> findAll() {
        return list(new PageQuery());
    }

    /**
     * Loads all entities with the given ids.
     *
     * @param ids ids to look up
     * @return the found entities; the repository result is returned as-is when it already is a
     *         {@link Collection}, otherwise it is copied into a list
     */
    @Override
    public Collection<T> findAllById(Collection<ID> ids) {
        Iterable<T> found = repository.findAllById(ids);
        if (found instanceof Collection) {
            return (Collection<T>) found;
        }
        List<T> result = new ArrayList<>();
        found.forEach(result::add);
        return result;
    }

    @Override
    public long count() {
        return repository.count();
    }

    @Override
    public void deleteById(ID id) {
        repository.deleteById(id);
    }

    /**
     * Deletes every document whose fields equal all non-null fields of {@code entity} (query by example).
     * Does nothing when the entity has no non-null field, because an empty bool query would match everything.
     *
     * @param entity example entity; each non-null field becomes a {@code term} condition
     */
    @Override
    public void delete(T entity) {
        final List<Field> fields = getEntityAllField();

        AtomicInteger num = new AtomicInteger();
        // One term clause per non-null field; num counts how many were added.
        BoolQueryBuilder conditions = buildFilterBoolQueryBuilder(fields, entity, num);

        if (num.intValue() < 1) {
            return;
        }

        // Delete-by-query only uses the main query (a post_filter set via withFilter is dropped),
        // so the conditions must be passed with withQuery.
        NativeSearchQuery deleteQuery = new NativeSearchQueryBuilder()
                .withQuery(conditions)
                .build();
        elasticsearchRestTemplate.delete(deleteQuery, getEntityClass());
    }

    @Override
    public void deleteAllById(Iterable<? extends ID> ids) {
        repository.deleteAllById(ids);
    }

    /**
     * Deletes each entity by example, one delete-by-query per entity (see {@link #delete(Object)}).
     *
     * @param entities example entities; {@code null} or empty means no-op
     */
    @Override
    public void deleteAll(Collection<? extends T> entities) {
        if (CollectionUtils.isEmpty(entities)) {
            return;
        }

        entities.forEach(this::delete);
    }

    @Override
    public void deleteAll() {
        repository.deleteAll();
    }

    @Override
    public Iterable<T> findAll(Sort sort) {
        return repository.findAll(sort);
    }

    @Override
    public Page<T> findAll(Pageable pageable) {
        return repository.findAll(pageable);
    }

    @Override
    public Page<T> searchSimilar(T entity, String[] fields, Pageable pageable) {
        return repository.searchSimilar(entity, fields, pageable);
    }

    /**
     * Runs a paged search built from {@code query}.
     *
     * @param query search conditions plus page/size
     * @param <Q>   concrete query type
     * @return the requested page; its total is the hit count reported by the same search request
     */
    @Override
    public <Q extends PageQuery> Page<T> search(Q query) {
        // The paged response already carries the total hit count, so no separate count search is issued.
        SearchHits<T> searchHits = commonSearch(query, true);
        List<T> content = searchHits.stream().map(SearchHit::getContent).collect(Collectors.toList());
        return new PageImpl<>(content, PageRequest.of(query.getPage(), query.getSize()), searchHits.getTotalHits());
    }

    /**
     * Counts the documents matching {@code query}.
     *
     * @param query search conditions
     * @param <Q>   concrete query type
     * @return total hit count reported by an unpaged search
     */
    @Override
    public <Q extends PageQuery> Long count(Q query) {
        return commonSearch(query, false).getTotalHits();
    }

    /**
     * Lists documents matching {@code query} without applying the query's page/size.
     *
     * @return the matched entities, or {@code null} when nothing matches
     */
    @Override
    public <Q extends PageQuery> List<T> list(Q query) {
        return contentOrNull(commonSearch(query, false));
    }

    /**
     * Lists documents matching {@code query}, restricting the returned fields to {@code columnName}.
     *
     * @param columnName fields to return; when empty, {@code returnFields()} is used instead
     * @return the matched entities, or {@code null} when nothing matches
     */
    @Override
    public <Q extends PageQuery> List<T> list(Q query, String... columnName) {
        return contentOrNull(commonSearch(query, false, columnName));
    }

    /**
     * Extracts the entities from a search response.
     *
     * @return the hit contents, or {@code null} when the reported total hit count is zero
     */
    private List<T> contentOrNull(SearchHits<T> searchHits) {
        if (searchHits.getTotalHits() < 1) {
            return null;
        }
        return searchHits.stream().map(SearchHit::getContent).collect(Collectors.toList());
    }

    @Override
    public <S extends T> void update(S entity) {
        commonUpdate(CollectionUtils.singleList(entity), false);
    }

    @Override
    public <S extends T> void updateAndFlush(S entity) {
        commonUpdate(CollectionUtils.singleList(entity), true);
    }

    @Override
    public <S extends T> void update(Collection<S> entities) {
        commonUpdate(entities, false);
    }

    @Override
    public <S extends T> void updateAndFlush(Collection<S> entities) {
        commonUpdate(entities, true);
    }


    /**
     * Saves the entities that do not exist yet and partially updates the ones that do.
     * <p>
     * Entities are partitioned by their {@link Id} value. Entities without an id, or whose id is not
     * returned by {@link #findAllById(Collection)}, are saved through the repository. Entities whose id
     * already exists are updated with their own non-null fields, using an immediate refresh.
     * When several input entities share an id, the last one wins.
     * </p>
     *
     * @param entities entities to save or update; {@code null} or empty means no-op
     */
    @Override
    public <S extends T> void saveOrUpdate(Collection<S> entities) {
        if (CollectionUtils.isEmpty(entities)) {
            return;
        }

        final List<Field> fields = getEntityAllField();

        List<S> toSave = new ArrayList<>();
        Map<ID, S> withId = new LinkedHashMap<>();
        for (S entity : entities) {
            ID id = extractId(fields, entity);
            if (id == null) {
                toSave.add(entity);
            } else {
                withId.put(id, entity);
            }
        }

        // Ids that are already stored in the index.
        Set<ID> existingIds = new HashSet<>();
        if (!withId.isEmpty()) {
            for (T stored : findAllById(new ArrayList<>(withId.keySet()))) {
                ID id = extractId(fields, stored);
                if (id != null) {
                    existingIds.add(id);
                }
            }
        }

        List<S> toUpdate = new ArrayList<>();
        withId.forEach((id, entity) -> {
            if (existingIds.contains(id)) {
                toUpdate.add(entity);
            } else {
                toSave.add(entity);
            }
        });

        if (!toSave.isEmpty()) {
            saveAll(toSave);
        }

        if (!toUpdate.isEmpty()) {
            commonUpdate(toUpdate, true);
        }
    }

    /**
     * Reads the value of the first {@link Id}-annotated field of {@code entity}.
     *
     * @return the id, or {@code null} when there is no such field or its value is null
     */
    @SuppressWarnings("unchecked")
    private ID extractId(List<Field> fields, T entity) {
        for (Field field : fields) {
            if (field.isAnnotationPresent(Id.class)) {
                return (ID) doGetFieldValue(field, entity);
            }
        }
        return null;
    }

    /**
     * Common partial update: each entity's non-null fields are sent as an update document keyed by its id.
     * Entities without an id are skipped.
     *
     * @param entities entities to update
     * @param flush    whether to force an immediate index refresh
     * @param <S>      entity type
     */
    private <S extends T> void commonUpdate(Collection<S> entities, boolean flush) {
        Map<ID, Map<String, Object>> tempMap = idTempMap(entities);

        if (tempMap == null) {
            return;
        }

        List<UpdateQuery> queries = new ArrayList<>(tempMap.size());
        tempMap.forEach((id, params) -> queries.add(
                UpdateQuery.builder(String.valueOf(id))
                        .withDocument(Document.from(params))
                        .build()));

        if (flush) {
            // Refresh immediately so the changes are searchable right away; this costs performance.
            elasticsearchRestTemplate.bulkUpdate(
                    queries,
                    BulkOptions.builder().withRefreshPolicy(RefreshPolicy.IMMEDIATE).build(),
                    elasticsearchRestTemplate.getIndexCoordinatesFor(getEntityClass()));
        } else {
            // No forced refresh (cheaper): relies on the template's default refresh policy.
            elasticsearchRestTemplate.bulkUpdate(queries, getEntityClass());
        }
    }

    /**
     * Builds an id -> (field name -> non-null value) map for the given entities, preserving input order.
     *
     * @return the map, or {@code null} when {@code entities} is empty or no entity has a non-null id
     */
    private <S extends T> Map<ID, Map<String, Object>> idTempMap(Collection<S> entities) {
        if (CollectionUtils.isEmpty(entities)) {
            return null;
        }

        final List<Field> fields = getEntityAllField();

        Map<ID, Map<String, Object>> tempMap = new LinkedHashMap<>();

        entities.forEach(entity -> buildIdMapParams(fields, entity, tempMap));

        if (CollectionUtils.isEmpty(tempMap)) {
            return null;
        }

        return tempMap;
    }

    /**
     * Collects the non-null fields of {@code entity} and registers them in {@code tempMap} under its id.
     * The entity is not registered when its {@link Id} field is null.
     * <p>
     * NOTE(review): keys are Java field names and values are raw field values; if the mapping renames
     * fields or uses types the client cannot serialize, the partial update may not match — verify.
     * </p>
     */
    private <S extends T> void buildIdMapParams(List<Field> fields, S entity, Map<ID, Map<String, Object>> tempMap) {
        // Field name -> value for this entity.
        Map<String, Object> params = new LinkedHashMap<>();

        for (Field field : fields) {
            Object o = doGetFieldValue(field, entity);

            if (o == null) {
                continue;
            }

            params.put(field.getName(), o);

            if (field.isAnnotationPresent(Id.class)) {
                // Primary key: the same params map keeps filling in after registration.
                tempMap.put((ID) o, params);
            }
        }
    }


    /**
     * Builds a bool query with one {@code term} clause per non-null field of {@code entity}.
     * <p>
     * NOTE(review): term queries do exact matching, so analyzed text fields may not match — confirm the mapping.
     * </p>
     *
     * @param num incremented once per clause added
     */
    private BoolQueryBuilder buildFilterBoolQueryBuilder(List<Field> fields, T entity, AtomicInteger num) {
        BoolQueryBuilder filter = QueryBuilders.boolQuery();

        for (Field field : fields) {
            Object obj = doGetFieldValue(field, entity);

            if (obj == null) {
                continue;
            }

            num.incrementAndGet();

            filter.must(QueryBuilders.termQuery(field.getName(), obj));
        }

        return filter;
    }

    /**
     * Reads a field value reflectively.
     *
     * @param field  field to read (made accessible first)
     * @param entity object to read from
     * @return the value, or {@code null} if it is null or access failed (the failure is logged)
     */
    private Object doGetFieldValue(Field field, T entity) {
        field.setAccessible(true);
        Object o = null;
        try {
            o = field.get(entity);
        } catch (IllegalAccessException e) {
            log.error("获取属性异常", e);
        }

        return o;
    }


    /**
     * Common search.
     *
     * @param query      search conditions
     * @param page       whether to apply the query's page/size
     * @param columnName fields to return; when empty, {@code returnFields()} is used
     * @param <Q>        concrete query type
     * @return the Elasticsearch response
     */
    private <Q extends PageQuery> SearchHits<T> commonSearch(Q query, boolean page, String... columnName) {
        NativeSearchQueryBuilder queryBuilder = new NativeSearchQueryBuilder();
        BoolQueryBuilder builder = QueryBuilders.boolQuery();
        // Subclass hook adds the actual conditions.
        buildFilterCondition(builder, query);

        queryBuilder.withQuery(builder);

        List<SortParam> sorts = sortFields(query);

        if (CollectionUtils.isNotEmpty(sorts)) {
            for (SortParam sort : sorts) {
                queryBuilder.withSort(SortBuilders.fieldSort(sort.getFieldName()).order(sort.getOrder()));
            }
        }
        if (page) {
            queryBuilder.withPageable(PageRequest.of(query.getPage(), query.getSize()));
        }

        NativeSearchQuery nativeSearchQuery = queryBuilder.build();

        // Fields to return.
        if (columnName != null && columnName.length > 0) {
            nativeSearchQuery.addFields(columnName);
        } else {
            nativeSearchQuery.addFields(returnFields());
        }

        return elasticsearchRestTemplate.search(nativeSearchQuery, this.getEntityClass());
    }

    /**
     * Hook for subclasses: adds search conditions derived from {@code queryParam} to {@code filter}.
     * The default implementation adds nothing, so searches match all documents.
     * <p>
     * Example override (cast {@code queryParam} to the concrete query type first):
     * </p>
     * <pre>{@code
     * filter.must(QueryBuilders.matchQuery("xxx", query.getXxx()));   // analyzed match
     * filter.must(QueryBuilders.termQuery("xxx", query.getXxx()));    // exact match
     * filter.must(QueryBuilders.rangeQuery("createTime").gte(query.getCreateTime() + " 00:00:00")); // range
     * }</pre>
     *
     * @param filter     bool query to add clauses to
     * @param queryParam incoming query
     * @param <Q>        concrete query type
     */
    public <Q extends PageQuery> void buildFilterCondition(BoolQueryBuilder filter, Q queryParam) {
        // Intentionally empty: subclasses override to add their conditions.
    }
}
